{--

    Given a positive integer number _n_, find how many
    positive integer numbers that are smaller than _n_
    have the same digit-sum.
    
    The digit sum is obtained by taking the digits of a number 
    written in decimal notation as numbers and adding them.
    -}

module examples.DigitSum where



{--
    Compute the digit sum of a number written in decimal notation.

    The sign of the argument is ignored: every digit contributes its
    absolute value, so a negative number has the same digit sum as its
    positive counterpart. The digit sum of 0 is 0.
    -}
digitSum :: Long -> Int
digitSum 0 = 0
-- `quot` truncates toward zero just like `rem`, so the argument shrinks
-- toward 0 for negative input as well. Pairing `rem` with the flooring
-- `div` would never reach the base case for a negative number.
digitSum n = abs (Long.int (n `rem` 10)) + digitSum (n `quot` 10)

{--
    Compute the digit sum of _n_, giving up early once it is known to
    exceed _limit_.

    Digits are added starting with the least significant one. As soon as
    the running total is greater than _limit_ that partial total is
    returned, so a result greater than _limit_ only tells the caller that
    the real digit sum is too large, not what it is.
    -}
digitSumLimit :: Int -> Long -> Int
digitSumLimit !limit !n = loop 0 n
    where
        -- total: digits added so far, rest: digits not yet looked at
        loop !total rest
            | rest == 0     = total
            | total > limit = total
            | otherwise     = loop (total + Long.int (rest `rem` 10)) (rest `div` 10)
            
{--
    Brute force solution: count the numbers from 1 up to, but not
    including, _n_ whose digit sum equals the digit sum of _n_.
    -}
slow n = count 1 0
    where
        --- the digit sum every candidate is compared against
        target = digitSum n
        -- candidate: number under test, hits: matches found so far
        count !candidate !hits
            | candidate >= n = hits
            | otherwise      = count (candidate + 1)
                                     (if digitSumLimit target candidate == target
                                        then hits + 1
                                        else hits)
        

-- slow result 222_222_222:   889830,  21.452 seconds
-- slow result 987_654_321: 39541589, 161.353 seconds
-- slowMain [] = println slows where slows = slow 222_222_222

--- Faster result
{--
    Build the smallest number with digit sum _sum_ that occupies _k_
    digit positions below the already fixed digits _leading_.

    Numbers are lists of digits with the least significant digit first.
    The result consists of the low digits computed here followed by
    _leading_. To keep the number small, the low positions are filled
    with 9s first, then the remainder, and any positions that are not
    needed are padded with zeros directly below _leading_.

    If _sum_ does not fit into _k_ positions, as many positions as
    necessary are used anyway.
    -}
smallest k leading sum =
    if needed sum < k
        -- more room than required: the topmost free position becomes 0
        then smallest (k - 1) (0 : leading) sum
        else if sum >= 9
            then 9 : smallest (k - 1) leading (sum - 9)
            else if sum > 0
                then sum : leading
                else leading

--- The number of digits needed to represent digit sum _s_, that is, _s_ divided by 9 and rounded up.
needed s = (s + 8) `div` 9
    
{--
    Convert a list of digits, least significant digit first, to the
    number it denotes. The empty list denotes 0.
    -}
toNum digits = fold shift 0L (reverse digits)
    where
        -- append digit d to the right of the value collected so far
        shift value d = 10L * value + fromInt d

{--
    Given a number as a reversed list of digits (least significant digit
    first), find the next greater number that has the same digit sum.

    Interestingly, we do not have to know the actual digit sum to do this!

    The next number is found like this:

    - Skip all trailing zeros. They cannot be decreased.
    - Take the lowest non-zero digit. One unit of it will be moved to a
      higher position.
    - Keep taking 9s that follow, since a 9 cannot be increased.
    - Add 1 to the first digit that is not 9. If there is no such digit,
      the number gets a new most significant digit 1.
    - Fill all positions below the increased digit with the smallest
      number whose digit sum is the sum of all digits taken, minus 1.

    Examples: 7 => 16, 16 => 25, 91 => 109, 120 => 201, 9900 => 10089.

    All trailing zeros must be counted as part of the positions to refill.
    Handling a number ending in 00 by computing the successor of the
    number without its last zero and then re-inserting a zero gives a
    result that is too large whenever the refilled part has a digit sum
    above 9 (10809 instead of 10089 for 9900).

    It is an error to pass the empty list or a number that is 0.
    -}
next [] = error "need at least one digit"
next ds = bump (dropWhile (== 0) ds)
    where
        -- ds consisted of zeros only
        bump []         = error "next 0 is undefined"
        -- d is the lowest non-zero digit; the positions to refill are
        -- the skipped zeros plus the position of d itself
        bump (d:higher) = incr (length ds - length higher) higher d
        -- k: number of positions to refill, s: sum of the digits taken
        incr k [] s    = smallest k [1] (s-1)
        incr k (9:a) s = incr (k+1) a (s+9)
        incr k (x:a) s = smallest k (x+1:a) (s-1)

{--
    List, in ascending order, the numbers smaller than _n_ that have the
    same digit sum as _n_.

    Starting with the smallest number that has the required digit sum,
    candidates are generated by applying 'next' repeatedly. The list ends
    at the first candidate that is not smaller than _n_.
    -}
fast n = takeWhile (< n) (map toNum candidates)
    where
        --- all numbers with the digit sum of n, as reversed digit lists
        candidates = iterate next (smallest 0 [] (digitSum n))

--- Print the results for 222 and 1_000_000, and the number of results for 222_222_222.
main _ =
    println (fast 222)
        >> println (fast 1_000_000)
        >> println (length (fast 222_222_222))
        